package com.tpvlog.rpc.provider;

import com.tpvlog.rpc.core.ServiceMeta;
import com.tpvlog.rpc.core.utils.RpcServiceHelper;
import com.tpvlog.rpc.protocol.codec.RpcDecoder;
import com.tpvlog.rpc.protocol.codec.RpcEncoder;
import com.tpvlog.rpc.protocol.handler.RpcRequestHandler;
import com.tpvlog.rpc.provider.annotation.RpcService;
import com.tpvlog.rpc.registry.RegistryService;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.config.BeanPostProcessor;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * RPC provider initializer.
 *
 * <p>On {@link #afterPropertiesSet()} it starts a Netty RPC server on a dedicated thread. As a
 * {@link BeanPostProcessor} it inspects every initialized bean. When the bean's class carries
 * {@link RpcService}, it caches the bean locally so the request handler can dispatch to it. It
 * then publishes the service's {@link ServiceMeta} to the {@link RegistryService}.
 */
public class RpcProviderInitializer implements InitializingBean, BeanPostProcessor {
    private static final Logger LOG = LoggerFactory.getLogger(RpcProviderInitializer.class);

    /** Name given to the thread that runs the Netty server, so it is identifiable in thread dumps. */
    private static final String SERVER_THREAD_NAME = "rpc-provider-server";

    /** Port the server binds to; also the port advertised to the registry. */
    private final Integer serverPort;
    /** Host address of this machine as resolved by {@link InetAddress#getLocalHost()}. */
    private final String serverAddress;
    /** Registry that annotated services are published to. */
    private final RegistryService serviceRegistry;
    /**
     * Service instance cache: service key -> service bean.
     * Written from {@link #postProcessAfterInitialization} and handed to every
     * {@link RpcRequestHandler} in the server pipeline, so it must be safe for concurrent access.
     */
    private final Map<String, Object> rpcServiceMap = new ConcurrentHashMap<>();

    /**
     * Creates the initializer and resolves the local host address.
     *
     * @param serverPort      port the RPC server binds to and advertises
     * @param serviceRegistry registry that annotated services are published to
     * @throws NullPointerException if {@code serverPort} or {@code serviceRegistry} is null
     * @throws RuntimeException     if the local host address cannot be resolved
     */
    public RpcProviderInitializer(Integer serverPort, RegistryService serviceRegistry) {
        // Fail fast: a null here would otherwise surface much later, on the server thread
        // or while registering individual services.
        this.serverPort = Objects.requireNonNull(serverPort, "serverPort");
        this.serviceRegistry = Objects.requireNonNull(serviceRegistry, "serviceRegistry");
        try {
            this.serverAddress = InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            throw new RuntimeException("unknown host", e);
        }
    }

    /**
     * Starts the RPC server on a separate, named thread.
     * The server blocks until its channel closes, so this keeps Spring startup from waiting on it.
     * Startup failures are logged, not propagated.
     */
    @Override
    public void afterPropertiesSet() {
        Thread serverThread = new Thread(() -> {
            try {
                startRpcServer();
            } catch (InterruptedException e) {
                // Restore the interrupt flag rather than silently discarding it.
                Thread.currentThread().interrupt();
                LOG.error("rpc server thread interrupted.", e);
            } catch (Exception e) {
                LOG.error("start rpc server error.", e);
            }
        }, SERVER_THREAD_NAME);
        serverThread.start();
    }

    /**
     * Binds a Netty server to {@code serverAddress:serverPort} and blocks until its channel closes.
     * Both event loop groups are shut down on exit, whether normal or exceptional.
     *
     * @throws Exception if binding fails or the waiting thread is interrupted
     */
    private void startRpcServer() throws Exception {
        EventLoopGroup boss = new NioEventLoopGroup();
        EventLoopGroup worker = new NioEventLoopGroup();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(boss, worker)
                    .channel(NioServerSocketChannel.class)
                    .childHandler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        protected void initChannel(SocketChannel socketChannel) {
                            // Per-connection pipeline: codec handlers, then the dispatcher that
                            // looks services up in rpcServiceMap.
                            socketChannel.pipeline()
                                    .addLast(new RpcEncoder())
                                    .addLast(new RpcDecoder())
                                    .addLast(new RpcRequestHandler(rpcServiceMap));
                        }
                    })
                    .childOption(ChannelOption.SO_KEEPALIVE, true);

            ChannelFuture channelFuture = bootstrap.bind(this.serverAddress, this.serverPort).sync();
            LOG.info("server addr {} started on port {}", this.serverAddress, this.serverPort);
            // Block this thread until the server channel is closed.
            channelFuture.channel().closeFuture().sync();
        } finally {
            boss.shutdownGracefully();
            worker.shutdownGracefully();
        }
    }

    /**
     * Caches and registers beans whose class is annotated with {@link RpcService}.
     *
     * <p>The bean is put into the local cache <em>before</em> it is published to the registry.
     * Otherwise a client could discover the service and call it before this provider can dispatch
     * to it. If registration fails, the cache entry is rolled back and the error is logged. The bean
     * itself is always returned unchanged.
     *
     * @param bean     the initialized bean
     * @param beanName the bean's name in the container
     * @return {@code bean}, unchanged
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        RpcService rpcService = bean.getClass().getAnnotation(RpcService.class);
        if (rpcService == null) {
            return bean;
        }
        // Fully-qualified name of the service interface, and its version.
        String serviceName = rpcService.service().getName();
        String serviceVersion = rpcService.version();

        String serviceKey = null;
        try {
            serviceKey = RpcServiceHelper.buildServiceKey(serviceName, serviceVersion);
            // Cache first so the handler can serve requests as soon as the service is discoverable.
            rpcServiceMap.put(serviceKey, bean);

            ServiceMeta serviceMeta = new ServiceMeta();
            serviceMeta.setAddress(serverAddress);
            serviceMeta.setPort(serverPort);
            serviceMeta.setService(serviceName);
            serviceMeta.setVersion(serviceVersion);
            serviceRegistry.register(serviceMeta);
        } catch (Exception e) {
            if (serviceKey != null) {
                // Roll back only our own entry; never evict a different bean under the same key.
                rpcServiceMap.remove(serviceKey, bean);
            }
            LOG.error("failed to register service {}#{}", serviceName, serviceVersion, e);
        }
        return bean;
    }
}
